%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
% Run the diffusion forecast and a Monte-Carlo ensemble forecast
% for Example 6.8, and compare their moments and densities
% Created by John Harlim
% Last edited: March 26, 2018
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
% Clear the workspace. Plain `clear` removes the variables; `clear all` would
% additionally flush cached functions, MEX files and breakpoints.
clear

% The script below reads x, dt, basis, peq and qest without defining them;
% they are assumed to come from these two files -- confirm.
load diffusionbasisvbdmauto2dv4
load truth2d

nvars = 1000;          % number of eigenfunctions
fsteps = 400;          % number of forecast steps
M = size(x,1);         % dimension of the observable
N = size(basis,1);     % length of training data (rows of basis = samples)
shifts = 1;            % time lag, in steps, used to build the forward operator

% The initial condition is taken from column 2*N of x, so x must be long enough.
if size(x,2) < 2*N
    error('x has %d columns, but at least 2*N = %d are required.', size(x,2), 2*N);
end

tdata = x(1:M,1:N)';                % training data, one sample per row (N x M)
vinit = x(1:M,2*N)';                % initial condition at time 2*N (1 x M)

%%% Every training sample is weighted by peq./qest, so a sample mean of
%%% f.*peq./qest stands in for the inner product <f,.>_peq.
wgt = peq./qest;
Phi = basis(:,1:nvars);          % first nvars eigenfunctions at the training points
Wx  = repmat(wgt,1,M);           % weights replicated across the M observable components

%%% normalizer(k) = <1,varphi_k>_peq            (1 x nvars)
normalizer = mean(Phi.*repmat(wgt,1,nvars));

%%% moment coefficients A_m(:,k) = <x.^m,varphi_k>_peq   (M x nvars)
A1 = (Wx.*tdata   )'*Phi/N;
A2 = (Wx.*tdata.^2)'*Phi/N;
A3 = (Wx.*tdata.^3)'*Phi/N;
A4 = (Wx.*tdata.^4)'*Phi/N;

% drop the helper variables so the workspace matches the original
clear wgt Phi Wx

%%% Matrix A: ForwardOp(i,j) = mean over n of
%%%   varphi_i(x_{n+shifts}) * varphi_j(x_n) * peq(x_n)/qest(x_n)
Phi = basis(:,1:nvars);          % first nvars eigenfunctions at the training points
wgt = peq./qest;                 % per-sample weight peq/qest
ForwardOp = Phi(1+shifts:N,:)'*(Phi(1:N-shifts,:).*repmat(wgt(1:N-shifts),1,nvars))/(N-shifts);

% The density at the training points is reconstructed as (peq.*Phi)*c.
% That matrix does not depend on the step, so build it once, not inside the loop.
PeqPhi = repmat(peq,1,nvars).*Phi;

p = zeros(N,fsteps+1);           % density at the training points, one column per time
Ex1 = zeros(M,fsteps+1);         % E[x^m](t) from the diffusion forecast, m = 1..4
Ex2 = zeros(M,fsteps+1);
Ex3 = zeros(M,fsteps+1);
Ex4 = zeros(M,fsteps+1);

% Gaussian initial condition centred at vinit with variance var_eq,
% evaluated at the training points and scaled so that mean(p(:,1)./qest) == 1
dx = tdata - repmat(vinit,N,1);
var_eq = 0.05;
p(:,1) = exp(-sum(dx.^2,2)/(2*var_eq));
Z = mean(p(:,1)./qest);
p(:,1) = p(:,1)/Z;

% project initial condition onto the eigenfunctions
c = Phi'*(p(:,1)./qest)/N;
c = c/(normalizer*c);            % rescale so that normalizer*c == 1

% initial statistics
Ex1(:,1) = A1*c;
Ex2(:,1) = A2*c;
Ex3(:,1) = A3*c;
Ex4(:,1) = A4*c;

% diffusion forecast: advance the coefficients, renormalize, record density and moments
for i = 1:fsteps
    c = ForwardOp*c;
    c = c/(normalizer*c);
    p(:,i+1) = PeqPhi*c;
    Ex1(:,i+1) = A1*c;
    Ex2(:,i+1) = A2*c;
    Ex3(:,i+1) = A3*c;
    Ex4(:,i+1) = A4*c;
end

% release the large N x nvars temporaries before the Monte-Carlo run
clear Phi wgt PeqPhi

% Monte-Carlo ensemble forecast. Each step applies an explicit-Euler substep
% for the quadratic terms (B(1)*u*v, B(2)*u^2), then an Euler-Maruyama substep
% for the linear drift (L - d*S)*x and the noise sigma*sqrtm(S)*dW. The second
% substep is evaluated at the intermediate state.
if M ~= 2
    error('The Monte-Carlo model below is 2-dimensional, but M = %d.', M);
end

K = 100000;                      % ensemble size used for the moments
nkeep = min(2000,K);             % members whose trajectories are stored in xens
                                 % (storing all K would need K*M*(fsteps+1) doubles)
x0 = repmat(vinit',1,K)+sqrt(var_eq)*randn(M,K);   % Gaussian initial ensemble
L = [0 1; 0 -1];
B = [.5 -.5];
S = [1 1/4;1/4 1];
Shalf = sqrtm(S);
sigma = 1;
d = 1/2;
sdt = sqrt(dt);

EEx1 = zeros(M,fsteps+1);        % ensemble estimates of E[x^m](t), m = 1..4
EEx2 = zeros(M,fsteps+1);
EEx3 = zeros(M,fsteps+1);
EEx4 = zeros(M,fsteps+1);
xens = zeros(nkeep,M,fsteps+1);  % stored members: (member, component, time)
x1 = zeros(M,K);                 % preallocate the updated ensemble

EEx1(:,1) = mean(x0,2);
EEx2(:,1) = mean(x0.^2,2);
EEx3(:,1) = mean(x0.^3,2);
EEx4(:,1) = mean(x0.^4,2);
xens(:,:,1) = x0(:,1:nkeep)';

for j = 1:fsteps
    % quadratic (nonlinear) substep
    x1(1,:) = x0(1,:)+dt*B(1)*x0(1,:).*x0(2,:);
    x1(2,:) = x0(2,:)+dt*B(2)*x0(1,:).*x0(1,:);

    % linear drift + stochastic forcing substep
    x1 = x1 + (L-S*d)*x1*dt + sigma*Shalf*sdt*randn(M,K);
    x0 = x1;

    xens(:,:,j+1) = x1(:,1:nkeep)';
    EEx1(:,j+1) = mean(x1,2);
    EEx2(:,j+1) = mean(x1.^2,2);
    EEx3(:,j+1) = mean(x1.^3,2);
    EEx4(:,j+1) = mean(x1.^4,2);
end

% plot evolution of moments: diffusion forecast (DF, dashed) against the
% Monte-Carlo ensemble forecast (EnF, solid); component 1 (u) in black,
% component 2 (v) in grey
grey = [.7 .7 .7];
tt = (0:fsteps)*dt;              % exactly fsteps+1 times, one per column of Ex*/EEx*
DFmom = {Ex1, Ex2, Ex3, Ex4};
MCmom = {EEx1, EEx2, EEx3, EEx4};
titles = {'E[x](t)','E[x^2](t)','E[x^3](t)','E[x^4](t)'};

figure(1)
for k = 1:4
    subplot(2,2,k)
    hold on
    plot(tt,DFmom{k}(1,:),'k--')
    plot(tt,MCmom{k}(1,:),'k')
    plot(tt,DFmom{k}(2,:),'--','color',grey)
    plot(tt,MCmom{k}(2,:),'color',grey)
    hold off
    title(titles{k})
    if k == 1
        legend('u: DF','u: EnF','v: DF','v: EnF','location','southeast')
    end
    if k >= 3                    % only the bottom row gets a time label
        xlabel('t')
    end
end
clear tt DFmom MCmom titles k

%print -depsc -r100 moments.eps

% plot evolution of the densities at various times:
% left column = diffusion forecast density, right column = ensemble members
in = [1 26 51 101 201];          % column indices into p / xens; time = (in-1)*dt
in = in(in <= fsteps+1);         % skip snapshots beyond the forecast horizon
nsnap = numel(in);
nens = min(2000,size(xens,1));   % ensemble members drawn per panel

figure(2)
for j = 1:nsnap
    % DF density, colored at every 20th training point
    subplot(nsnap,2,2*(j-1)+1)
    scatter(tdata(1:20:end,1),tdata(1:20:end,2),2,p(1:20:end,in(j)))
    axis([-4 4 -4 4])
    caxis([0 .25])
    title(['DF: t=',num2str((in(j)-1)*dt)])
    if j == nsnap
        xlabel('u')
    end
    ylabel('v')

    % Monte-Carlo ensemble sample at the same time
    subplot(nsnap,2,2*j)
    plot(xens(1:nens,1,in(j)),xens(1:nens,2,in(j)),'k.')
    axis([-4 4 -4 4])
    title(['Ens Fcst: t=',num2str((in(j)-1)*dt)])
    if j == nsnap
        xlabel('u')
    end
end
%orient tall, print -depsc -r100 probdensities.eps
